
[DSV4] Add BF16 and MXFP8 A2A support for flashinfer a2a one sided #40960

Merged
ywang96 merged 6 commits into vllm-project:main from zyongye:nvlink_one_sided_bf16_support_upstream on Apr 30, 2026

Conversation

zyongye (Member) commented Apr 27, 2026

Purpose

Originally, the FlashInfer one-sided a2a only supported nvfp4 dispatch. This PR adds BF16 and MXFP8 dispatch.

Test Plan

gsm8k on V4-Flash

Test Result

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9583|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9591|±  |0.0055|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.


claude (Bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

mergify (Bot, Contributor) commented Apr 27, 2026

Documentation preview: https://vllm--40960.org.readthedocs.build/en/40960/

mergify Bot added the documentation and nvidia labels on Apr 27, 2026
gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces support for bf16 dispatch and deferred input quantization within the flashinfer_nvlink_one_sided MoE kernel, specifically to accommodate expert types like trtllm_mxfp4. The implementation updates the MoeAlltoAll workspace to support dynamic payload sizes and modifies the preparation logic to handle cases where quantization is performed post-dispatch. Feedback was provided regarding a potential AttributeError when quant_config is None and the necessity of updating existing validation checks to ensure the new logic is reachable.

Comment on lines +243 to +254
if defer_input_quant or quant_config.quant_dtype is None:
    # Experts (e.g. trtllm_mxfp4 with mxfp8 activations) quantize
    # post-dispatch; ship bf16 tokens with no per-token scale payload.
    dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False
elif quant_config.quant_dtype == "nvfp4":
    dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True
else:
    raise NotImplementedError(
        "flashinfer_nvlink_one_sided dispatch only supports nvfp4, "
        "bf16, and defer_input_quant paths today; got "
        f"quant_dtype={quant_config.quant_dtype!r}"
    )
gemini-code-assist (Bot) commented (severity: high):

This logic is currently unreachable for non-nvfp4 models (such as mxfp4) because of the ValueError check at lines 232-239 (outside this diff hunk). To support BF16 and deferred quantization for other formats, that validation block should be removed or updated. Additionally, the code should safely handle cases where quant_config is None (e.g., for unquantized models) to avoid an AttributeError when accessing quant_config.quant_dtype.

        quant_dtype = quant_config.quant_dtype if quant_config is not None else None
        if defer_input_quant or quant_dtype is None:
            # Experts (e.g. trtllm_mxfp4 with mxfp8 activations) quantize
            # post-dispatch; ship bf16 tokens with no per-token scale payload.
            dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False
        elif quant_dtype == "nvfp4":
            dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True
        else:
            raise NotImplementedError(
                "flashinfer_nvlink_one_sided dispatch only supports nvfp4, "
                "bf16, and defer_input_quant paths today; got "
                f"quant_dtype={quant_dtype!r}"
            )

A collaborator commented:

I think we need to change this condition as well; otherwise, the branch would early-exit here for bf16.
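
For illustration only, a minimal sketch of how that earlier validation might be relaxed so the bf16 / defer_input_quant branch in the hunk above becomes reachable. The actual check at lines 232-239 is outside this diff, so the function and names below are assumed from the hunk, not copied from the vLLM source:

```python
def check_one_sided_dispatch_support(quant_config, defer_input_quant: bool) -> None:
    # Hypothetical relaxation of the early ValueError check: allow nvfp4,
    # unquantized (bf16), and deferred-quantization configs through, and only
    # reject dtypes that the one-sided dispatch genuinely cannot ship.
    quant_dtype = quant_config.quant_dtype if quant_config is not None else None
    if quant_dtype not in (None, "nvfp4") and not defer_input_quant:
        raise ValueError(
            "flashinfer_nvlink_one_sided dispatch supports only nvfp4, "
            f"bf16, or defer_input_quant paths; got quant_dtype={quant_dtype!r}"
        )
```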

zyongye mentioned this pull request on Apr 28, 2026
zyongye force-pushed the nvlink_one_sided_bf16_support_upstream branch from e86f2a3 to b07b85a on April 28, 2026 at 22:19
zyongye changed the title from "Add BF16 A2A support for flashinfer a2a one sided" to "[DSV4] Add BF16 and MXFP8 A2A support for flashinfer a2a one sided" on Apr 28, 2026
ywang96 added the ready label on Apr 28, 2026
hjjq (Contributor) left a comment:

LGTM, thanks!!

bobboli commented Apr 29, 2026

LGTM overall. I have confirmed that the gen step time is ~40ms for AG/RS and ~35ms for NVLinkOneSided A2A under the following config:

  • Model: DeepSeek-V4-Flash
  • ISL / OSL: 8192 / 1024
  • Local batch size: 512 per DP rank
  • DP8EP8, 8x B200
  • kv_cache_dtype: fp8

bobboli commented Apr 29, 2026

A small missing piece is payload_in_workspace for combine. It means the MoE module could write its output directly into the A2A workspace, so there is no need for an extra copy of the tokens from a local tensor onto the workspace.
Related code in TRTLLM is here.
It requires that the MoE op can be given a pre-allocated tensor as its output, which is intrusive.

zyongye (Member, Author) commented Apr 29, 2026

> A small missing piece is payload_in_workspace for combine. It means the MoE module could write its output directly into the A2A workspace, so there is no need for an extra copy of the tokens from a local tensor onto the workspace. Related code in TRTLLM is here. It requires that the MoE op can be given a pre-allocated tensor as its output, which is intrusive.

Thanks. We do have a pre-allocated workspace buffer, but right now we just copy into it manually. I will take a look at the trtllm code!

bobboli commented Apr 29, 2026

> Thanks. We do have a pre-allocated workspace buffer, but right now we just copy into it manually. I will take a look at the trtllm code!

Yes, the point is that a portion of the workspace is viewed as a Tensor and passed into the MoE op as the output tensor, so that the MoE op writes its output directly into the workspace.
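
A minimal sketch of the payload_in_workspace idea under discussion, assuming a PyTorch MoE op that accepts a pre-allocated output tensor; `moe_op`, its `out=` parameter, and the buffer layout are hypothetical, not the TRTLLM or flashinfer API:

```python
import torch

def moe_apply_into_workspace(workspace: torch.Tensor, num_tokens: int,
                             hidden_dim: int, moe_op, *moe_args) -> torch.Tensor:
    # View the leading region of the pre-allocated A2A combine workspace as
    # this rank's bf16 output payload, instead of allocating a local tensor.
    out = workspace.view(torch.bfloat16)[: num_tokens * hidden_dim]
    out = out.view(num_tokens, hidden_dim)
    # Hypothetical: the MoE op writes directly into `out`, so no extra copy
    # from a local tensor onto the workspace is needed before combine.
    moe_op(*moe_args, out=out)
    return out
```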

zyongye force-pushed the nvlink_one_sided_bf16_support_upstream branch 2 times, most recently from 2d72303 to 32df1aa on April 30, 2026 at 16:33
zyongye and others added 6 commits on April 30, 2026:
…ll2all

The one-sided MoeAlltoAll dispatch workspace was hardcoded for nvfp4
hidden states + fp8 scales, so any other activation dtype overran the
buffer. Parameterize the workspace sizing by bytes-per-elem and whether
an fp8 scale payload is present, then route non-nvfp4 quant configs to
a bf16 dispatch (2 B/elem, no scale) via a new defer_input_quant hint.

trtllm_mxfp4 experts already advertise expects_unquantized_inputs=True
(they call mxfp8_quantize internally). Wire make_mxfp4_moe_kernel to
pass that signal into maybe_make_prepare_finalize, and have the one-
sided prepare() honor the per-call defer_input_quant flag by shipping
a1 as bf16 with no scale payload. Two-sided already handled this.

NOTE: the flashinfer moe_a2a_dispatch C++ kernel only templates top_k
in {1, 2, 4, 8}; models with other top_k (e.g. DeepSeek-V4 top_k=6)
must use flashinfer_nvlink_two_sided instead.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
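
A rough sketch (illustrative names only, not the actual MoeAlltoAll code) of the per-token payload sizing this commit describes, where bf16 dispatch costs 2 B/elem with no scale payload and nvfp4 dispatch packs two elements per byte plus fp8 block scales (block size assumed to be 16):

```python
def dispatch_payload_bytes_per_token(hidden_dim: int,
                                     bytes_per_elem: int,
                                     has_fp8_scale: bool,
                                     nvfp4_block_size: int = 16) -> int:
    if bytes_per_elem == 0:
        # nvfp4 path: hidden states packed two 4-bit values per byte.
        payload = hidden_dim // 2
    else:
        # bf16 / defer_input_quant path: e.g. 2 bytes per element, no scales.
        payload = hidden_dim * bytes_per_elem
    if has_fp8_scale:
        # One fp8 scale per block of nvfp4 elements (block size assumed).
        payload += hidden_dim // nvfp4_block_size
    return payload

# bf16 dispatch:  dispatch_payload_bytes_per_token(h, 2, False) == 2 * h
# nvfp4 dispatch: dispatch_payload_bytes_per_token(h, 0, True)  == h // 2 + h // 16
```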

Padded rows in the [ep_size, max_num_tokens, ...] workspace retain
stale topk_ids from prior dispatch calls (the workspace is zeroed only
once at init). Those stale ids cause the downstream trtllm_fp4 grouped
GEMM to do phantom work for random local experts every layer, which
(a) inflates expert GEMM time and (b) creates the cross-rank skew that
the combine kernel then has to wait on.

Setting `invalid_token_expert_id` to `num_experts` (one past the valid
expert range) makes the flashinfer worker overwrite all top_k topk_ids
slots of padded rows with that sentinel (moeA2ASanitizeExpertIdsKernel
in moeAlltoAllKernels.cu); the trtllm grouped GEMM then sees those
rows as routed to no local expert (out of [local_expert_offset,
local_expert_offset + local_num_experts)) and skips them.

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
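
To make the sentinel behaviour concrete, a small illustrative check (assumed names, not the flashinfer kernel itself): rows whose topk_ids were overwritten with num_experts fall outside every rank's local expert range, so the grouped GEMM does no work for them:

```python
import torch

def padded_rows_skipped(topk_ids: torch.Tensor,
                        local_expert_offset: int,
                        local_num_experts: int) -> torch.Tensor:
    # topk_ids: [num_rows, top_k]; padded rows hold the sentinel num_experts.
    in_local_range = (topk_ids >= local_expert_offset) & (
        topk_ids < local_expert_offset + local_num_experts
    )
    # True for rows that route to no local expert, i.e. rows the grouped
    # GEMM can skip instead of doing phantom work on stale ids.
    return ~in_local_range.any(dim=-1)
```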
When prepare-side mxfp8_quantize pads K to mx_alignment (e.g. gpt-oss
hidden=2880 -> 3072 with align=256), pre-PR's torch.empty_like(hidden_states)
naturally produced an unpadded output because hidden_states was the original
bf16 input. With prepare-side quantize, hidden_states entering apply() is the
padded fp8 tensor, so allocating output by self.hidden_dim (which is the
post-roundup padded value from maybe_roundup_sizes) propagates padding into
lm_head. Use moe_config.hidden_dim_unpadded so trtllm internally truncates
back to the original hidden, matching pre-PR behavior. Apply the same fix to
the modular workspace_shapes for non-aligned hiddens with EP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
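
A minimal sketch of the allocation change described above; `hidden_dim_unpadded` mirrors the commit message, while the function and remaining names are illustrative assumptions:

```python
import torch

def alloc_moe_output(num_tokens: int, moe_config, device: str = "cuda") -> torch.Tensor:
    # Allocate by the original (unpadded) hidden size so any K padding added
    # for mx_alignment during mxfp8_quantize (e.g. 2880 -> 3072) is truncated
    # before the output ever reaches lm_head.
    hidden = getattr(moe_config, "hidden_dim_unpadded", moe_config.hidden_dim)
    return torch.empty(num_tokens, hidden, dtype=torch.bfloat16, device=device)
```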
zyongye force-pushed the nvlink_one_sided_bf16_support_upstream branch from 32df1aa to 96d415a on April 30, 2026 at 16:35
@ywang96 ywang96 merged commit b4806c8 into vllm-project:main Apr 30, 2026
77 of 81 checks passed
github-project-automation Bot moved this from Ready to Done in NVIDIA on Apr 30, 2026
ywang96 added this to the v0.20.1 milestone on Apr 30, 2026
khluu mentioned this pull request on May 1, 2026
khluu pushed a commit that referenced this pull request on May 1, 2026
…40960)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
(cherry picked from commit b4806c8)

Labels

ci/build, cpu, deepseek, documentation, frontend, gpt-oss, kv-connector, mistral, new-model, nvidia, performance, ready, speculative-decoding, tool-calling, v1

Projects

Status: Done

5 participants